import numpy as np
import torch

from a2c_ppo_acktr import utils
from a2c_ppo_acktr.envs import make_vec_envs
import wandb


def evaluate(actor_critic, env_name, seed, num_processes, eval_log_dir,
             device, args, itr, num_eval_episodes=50, ob_rms=None):
    """Roll out the policy deterministically and report episode-reward stats.

    Builds ``num_processes`` evaluation environments seeded with
    ``seed + num_processes``. It then steps them with
    ``actor_critic.act(..., deterministic=True)`` until at least
    ``num_eval_episodes`` episodes have finished. Finally it prints the mean
    and standard deviation of the collected episode rewards and, when
    ``args.log_wandb`` is set, logs them to Weights & Biases.

    Each finished episode is also classified by the observation fed to the
    policy on its final step. It counts as mode 0 if the norm of that
    observation's last three entries is smaller than the norm of the three
    entries before them, and as mode 1 otherwise. For ``Reacher2-v0`` the
    smoothed ratio ``count[0] / (count[1] + 1)`` is logged as ``mode_ratio``.

    Args:
        actor_critic: Policy exposing ``act(obs, hidden, masks,
            deterministic=...)`` and ``recurrent_hidden_state_size``.
        env_name: Environment id passed to ``make_vec_envs``; also decides
            whether ``mode_ratio`` is logged.
        seed: Base seed; evaluation envs use ``seed + num_processes``.
        num_processes: Number of parallel evaluation environments.
        eval_log_dir: Log directory forwarded to ``make_vec_envs``.
        device: Torch device for the recurrent state and masks.
        args: Run configuration; only ``args.log_wandb`` is read.
        itr: Training iteration, logged as ``epoch``.
        num_eval_episodes: Minimum number of finished episodes to collect
            before stopping (default 50).
        ob_rms: Optional observation statistics to install on the
            normalisation wrapper returned by ``utils.get_vec_normalize``.
            When ``None`` the wrapper's own statistics are left in place.

    Returns:
        None. Results are printed and optionally sent to wandb.
    """
    eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes,
                              None, eval_log_dir, device, True)
    try:
        vec_norm = utils.get_vec_normalize(eval_envs)
        if vec_norm is not None:
            vec_norm.eval()
            if ob_rms is not None:
                vec_norm.ob_rms = ob_rms

        eval_episode_rewards = []
        # count[0]: episodes ending in mode 0; count[1]: episodes in mode 1.
        count = [0, 0]

        obs = eval_envs.reset()
        eval_recurrent_hidden_states = torch.zeros(
            num_processes, actor_critic.recurrent_hidden_state_size,
            device=device)
        eval_masks = torch.zeros(num_processes, 1, device=device)

        while len(eval_episode_rewards) < num_eval_episodes:
            with torch.no_grad():
                _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                    obs,
                    eval_recurrent_hidden_states,
                    eval_masks,
                    deterministic=True)

            # Snapshot the observation *before* stepping. The episode
            # classification below uses the observation that preceded the
            # terminal step (the vec env presumably auto-resets finished
            # episodes, so the post-step obs would belong to the next one).
            # One flat row per process, so process i's observation is
            # prev_obs[i] for any num_processes (squeeze() only worked for 1).
            prev_obs = np.asarray(obs.cpu()).reshape(num_processes, -1)

            # NOTE(review): only action[0] is sent to the envs. This looks
            # like act() returns a tuple/nested structure for this policy;
            # with a plain (num_processes, act_dim) tensor it would be only
            # process 0's action. Confirm against the actor_critic in use.
            obs, _, done, infos = eval_envs.step(action[0])

            # Mask is 0 for processes whose episode just ended, which
            # resets their recurrent state on the next act() call.
            eval_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=device)

            for proc_idx, info in enumerate(infos):
                if 'episode' not in info:
                    continue
                eval_episode_rewards.append(info['episode']['r'])
                last_obs = prev_obs[proc_idx]
                if np.linalg.norm(last_obs[-3:]) < np.linalg.norm(last_obs[-6:-3]):
                    count[0] += 1
                else:
                    count[1] += 1
    finally:
        # Always release the evaluation environments, even on error.
        eval_envs.close()

    mean_reward = np.mean(eval_episode_rewards)
    std_reward = np.std(eval_episode_rewards)

    if args.log_wandb:
        log_data = {"mean_reward": mean_reward,
                    "std_reward": std_reward}
        if env_name == "Reacher2-v0":
            # +1 in the denominator avoids division by zero when no
            # episode ended in mode 1.
            log_data["mode_ratio"] = count[0] / (count[1] + 1)
        log_data["epoch"] = itr
        wandb.log(log_data)

    print(" Evaluation using {} episodes: mean reward {:.5f}, std reward {:.5f}\n".format(
        len(eval_episode_rewards), mean_reward, std_reward))
